import numpy as np
import pandas as pd
from netZooPy.ligress.bonobo import Bonobo
from bonobo_competitors_scaled import lioness
from bonobo_competitors_scaled import spcc
from bonobo_competitors_scaled import sweet
import random
from math import floor

# Seed Python's built-in `random` module so the random gene and sample
# selection done with it in this file is repeatable across runs.
# NOTE(review): this does not seed NumPy's global generator, so the
# np.random.multivariate_normal draws used for the simulated expression are
# not covered by this seed -- confirm whether np.random should be seeded too.
random.seed(10)

def _row_mse(truth_row, estimate_row):
  """Return the mean squared difference between two equally long rows."""
  return np.mean(np.array((truth_row - estimate_row) ** 2))


def _load_sweet_network(sample, n_genes, z_cutoff=1.96):
  """Rebuild the dense, symmetric SWEET network of one sample from its z-score file."""
  scores = pd.read_csv(
    './SWEET_main/simulation_sweet/outputs/' + str(sample) + '_zscore.txt',
    sep="\t",
  )
  z = np.array(list(scores['z_score']))
  # Edges whose |z| does not exceed the cutoff are set to 0.
  z = np.where(np.abs(z) <= z_cutoff, 0, z)

  net = np.zeros(shape=(n_genes, n_genes))
  # Assumes the file lists gene pairs in row-major upper-triangle order
  # ((0,1), (0,2), ..., (1,2), ...) -- TODO confirm against the SWEET output.
  net[np.triu_indices(n_genes, k=1)] = z
  # Mirror the upper triangle. Writing the same vector through
  # np.tril_indices would NOT do this: both index sets are enumerated
  # row-major, so the lower triangle would not be the transpose of the upper.
  net = net + net.T
  np.fill_diagonal(net, 1)
  return net


def simulate_bonobo_gene0(sim_mean, sim_cov, ngene = 100, nsample = 100, mix_prop = 0.05, nsample_ind = 100):
  """Compare BONOBO, LIONESS, SPCC and SWEET when one gene is zeroed in some samples.

  For each of ``nsample`` individuals, ``nsample_ind`` expression profiles are
  drawn from a multivariate normal. For a randomly chosen subset of
  individuals (a fraction ``mix_prop``), one randomly chosen gene is set to 0
  in every profile. The per-individual mean profiles form the expression
  matrix handed to each method, and the per-individual correlation matrices
  serve as the reference. The error of every method is measured on the row of
  the zeroed gene, for the individuals that were NOT zeroed.

  Parameters
  ----------
  sim_mean : object such that ``sim_mean["x"].to_numpy()`` is the 1-D mean
    vector (e.g. a DataFrame with a column ``"x"``).
  sim_cov : covariance matrix passed to ``np.random.multivariate_normal``.
  ngene : unused; the number of genes is taken from ``sim_mean``. Kept so
    existing callers keep working.
  nsample : number of simulated individuals.
  mix_prop : fraction of individuals in which the chosen gene is zeroed
    (``floor(nsample * mix_prop)`` individuals).
  nsample_ind : number of profiles simulated per individual.

  Returns
  -------
  tuple of five lists
    ``(mse_bonobo, mse_bonobo_sparse, mse_lioness, mse_spcc, mse_sweet)``,
    each holding one mean squared error per non-zeroed individual, in
    increasing order of individual index.

  Notes
  -----
  The SWEET comparison reads
  ``./SWEET_main/simulation_sweet/outputs/<i>_zscore.txt`` for every
  non-zeroed individual ``i``.
  """
  sim_mean = sim_mean["x"].to_numpy()
  n_genes = sim_mean.shape[0]

  # Pick a random gene. randrange excludes the upper bound, whereas
  # randint(0, n_genes) could return n_genes itself and index out of range.
  gene_rand = random.randrange(n_genes)
  # Pick the random subset of individuals in which that gene is zeroed.
  randsize = floor(nsample*mix_prop)
  ind_rand = random.sample(range(0,nsample),randsize)
  zeroed = set(ind_rand)  # constant-time membership tests in the loops below

  # Simulate individual gene expression and save each individual's mean
  # profile and correlation matrix.
  ind_mean = []
  ind_cor = []
  for i in range(nsample):
    expr_sim = np.random.multivariate_normal(sim_mean, sim_cov, nsample_ind).T
    if i in zeroed:
      expr_sim[gene_rand, :] = 0
    expr_mean = np.mean(expr_sim, 1)
    expr_cor = np.corrcoef(expr_sim)
    if i in zeroed:
      # A constant (all-zero) gene yields NaN correlations: replace them by 0
      # and restore the diagonal entry to 1.
      expr_cor = np.nan_to_num(expr_cor)
      expr_cor[gene_rand, gene_rand] = 1
    ind_mean.append(expr_mean)
    ind_cor.append(expr_cor)

  # Genes in rows, individuals in columns.
  expression = pd.DataFrame(np.array(ind_mean).transpose())

  bonobo_obj = Bonobo(expression)
  bonobo_obj.run_bonobo(keep_in_memory=True, output_fmt='.txt')
  ind_bonobo = bonobo_obj.bonobos

  bonobo_obj = Bonobo(expression)
  bonobo_obj.run_bonobo(keep_in_memory=True, output_fmt='.txt', sparsify=True)
  ind_bonobo_sparse = bonobo_obj.bonobos

  # The return value is not used here.
  # NOTE(review): presumably this call writes the per-sample z-score files
  # read by _load_sweet_network -- confirm against the sweet implementation.
  sweet(expression.T)

  ind_lioness = []
  ind_spcc = []
  for i in range(nsample):
    ind_lioness.append(lioness(expression, i))
    ind_spcc.append(spcc(expression, i))

  # Compute the MSE on the zeroed gene's row for every non-zeroed individual.
  mse_bonobo = []
  mse_bonobo_sparse = []
  mse_lioness = []
  mse_spcc = []
  mse_sweet = []
  for i in (s for s in range(nsample) if s not in zeroed):
    truth_row = ind_cor[i][gene_rand,:]

    mse_bonobo.append(_row_mse(truth_row, ind_bonobo[i].iloc[gene_rand,:]))
    mse_bonobo_sparse.append(_row_mse(truth_row, ind_bonobo_sparse[i].iloc[gene_rand,:]))
    mse_lioness.append(_row_mse(truth_row, ind_lioness[i][gene_rand,:]))
    mse_spcc.append(_row_mse(truth_row, ind_spcc[i][gene_rand,:]))

    sweet_net = _load_sweet_network(i, n_genes)
    mse_sweet.append(_row_mse(truth_row, sweet_net[gene_rand,:]))

  return mse_bonobo, mse_bonobo_sparse, mse_lioness, mse_spcc, mse_sweet

